"""
Phase 2: Varying-Coefficient Estimation
========================================
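For each DV, first estimate a Z-only baseline (Y on Z₁, Z₂, Z₃), then for
each DV × moderator combination estimate:
    Y = β₁Z₁ + β₂Z₂ + β₃Z₃ + γ₁(Z₁×M) + γ₂(Z₂×M) + γ₃(Z₃×M) + ε
where M is the moderator, so the conditional effect of Zₖ is βₖ + γₖM.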

This produces the "conditional demographic effect" — how Z's coefficient
changes across institutional regimes, income levels, and policy states.

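Moderators: safe_issuer, qe_active, eurozone, income_high, income_low,
            is_oecd, kaopen_sat, oadr_15, oadr_20, oadr_25
            (a moderator is skipped when its Z_k_x_<suffix> interaction
            columns are absent from the panel)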

DVs: ca_gdp, gross_savings_gdp, gross_investment_gdp, govt_bond_10y,
     nfa_gdp, rer_index
"""

import pandas as pd
import numpy as np
from pathlib import Path
import sys
import warnings
# Suppress every Python warning for the whole run.
# NOTE(review): this also hides numerical/convergence warnings that numpy,
# pandas or PanelGLS may emit — consider narrowing the filter.
warnings.filterwarnings('ignore')

# PROJECT_DIR is the parent of the directory containing this file;
# ROOT_DIR is one level above that.
PROJECT_DIR = Path(__file__).resolve().parent.parent
ROOT_DIR = PROJECT_DIR.parent
# Make the sibling "multilateral" project's src directory importable so that
# PanelGLS can be loaded from its `model` module. Must run before the import.
sys.path.insert(0, str(ROOT_DIR / "multilateral" / "src"))
from model import PanelGLS

# Processed-data directory: holds the input panel and receives the results CSV.
DATA = PROJECT_DIR / "data" / "processed"
# Destination for the markdown tables; created at import time if missing.
OUT_TABLES = PROJECT_DIR / "output" / "tables"
OUT_TABLES.mkdir(parents=True, exist_ok=True)

# ── Configuration ─────────────────────────────────────────────────────

# Dependent variables, each modelled separately (in this order).
DVS = [
    'ca_gdp',
    'gross_savings_gdp',
    'gross_investment_gdp',
    'govt_bond_10y',
    'nfa_gdp',
    'rer_index',
]

# Moderator label -> suffix used in the interaction column names, which
# follow the pattern "<Z var>_x_<suffix>" (e.g. "Z_1_x_safe").
# Insertion order is the order in which moderators are estimated.
MODERATORS = dict([
    ('safe_issuer', 'safe'),
    ('qe_active', 'qe'),
    ('eurozone', 'emu'),
    ('income_high', 'high'),
    ('income_low', 'low'),
    ('is_oecd', 'oecd'),
    ('kaopen_sat', 'ksat'),
    ('oadr_15', 'oadr15'),
    ('oadr_20', 'oadr20'),
    ('oadr_25', 'oadr25'),
])

# Demographic regressors present in every specification.
Z_VARS = [f'Z_{k}' for k in (1, 2, 3)]


def stars(p):
    """Return the conventional significance marker for p-value ``p``.

    '***' for p < 0.01, '**' for p < 0.05, '*' for p < 0.1, otherwise ''.
    A NaN p-value fails every comparison and therefore yields ''.
    """
    for cutoff, marker in ((0.01, '***'), (0.05, '**'), (0.1, '*')):
        if p < cutoff:
            return marker
    return ''


def fmt(val, p):
    """Render coefficient ``val`` to three decimals followed by the stars for ``p``."""
    number = format(val, '.3f')
    return number + stars(p)


def run_gls(df, y_var, x_vars, min_obs=50):
    """Fit PanelGLS of ``y_var`` on ``x_vars`` and collect the estimates.

    Rows with a missing value in the outcome, any regressor, ``iso3`` or
    ``year`` are dropped (listwise deletion) before fitting, so different
    regressor sets can yield different estimation samples.

    Parameters
    ----------
    df : pandas.DataFrame
        Panel containing ``y_var``, every column in ``x_vars``, ``iso3``
        and ``year``.
    y_var : str
        Dependent-variable column.
    x_vars : sequence of str
        Regressor columns, in the order their coefficients are reported.
    min_obs : int, default 50
        Minimum number of complete rows required to attempt a fit.

    Returns
    -------
    dict or None
        ``n_obs``, ``n_countries`` and ``r_squared`` plus ``<var>_coef``,
        ``<var>_se`` and ``<var>_p`` for every regressor. ``None`` when
        fewer than ``min_obs`` complete rows remain or the fit raises
        (the error is reported on stderr).
    """
    # A list (not a tuple) is required for pandas multi-column selection.
    x_vars = list(x_vars)
    cols = [y_var, *x_vars, 'iso3', 'year']
    sub = df[cols].dropna()
    if len(sub) < min_obs:
        return None

    gls = PanelGLS()
    try:
        gls.fit(sub[y_var].values, sub[x_vars].values,
                sub['iso3'].values, sub['year'].values)
    except Exception as exc:
        # Best effort: keep the batch running, but say why this specification
        # was dropped instead of letting it look like missing data.
        print(f"    PanelGLS fit failed for {y_var} ~ {' + '.join(x_vars)}: "
              f"{type(exc).__name__}: {exc}", file=sys.stderr)
        return None

    result = {'n_obs': gls.n_obs, 'n_countries': gls.n_countries,
              'r_squared': gls.r_squared}
    for i, var in enumerate(x_vars):
        result[f'{var}_coef'] = gls.beta[i]
        result[f'{var}_se'] = gls.se[i]
        result[f'{var}_p'] = gls.pvalues[i]
    return result


# ── Main ──────────────────────────────────────────────────────────────

def _result_columns():
    """Column order of the phase-2 results table (also used when it is empty)."""
    k_range = range(1, len(Z_VARS) + 1)
    return (['dv', 'moderator', 'n_obs', 'n_countries', 'r2']
            + [f'{z}_coef' for z in Z_VARS]
            + [f'{z}_p' for z in Z_VARS]
            + [f'Z{k}x_{stat}' for k in k_range for stat in ('coef', 'p')])


def _model_row(dv, moderator, res, int_vars=()):
    """Flatten one run_gls result into a row of the results table.

    The interaction columns (Z1x_coef, Z1x_p, ...) take their values from
    ``int_vars`` in Z_VARS order; they are NaN when ``int_vars`` is empty,
    as for the Z-only baseline.
    """
    row = {'dv': dv, 'moderator': moderator,
           'n_obs': res['n_obs'], 'n_countries': res['n_countries'],
           'r2': res['r_squared']}
    row.update({f'{z}_coef': res[f'{z}_coef'] for z in Z_VARS})
    row.update({f'{z}_p': res[f'{z}_p'] for z in Z_VARS})
    for k in range(1, len(Z_VARS) + 1):
        if int_vars:
            var = int_vars[k - 1]
            coef, p = res[f'{var}_coef'], res[f'{var}_p']
        else:
            coef = p = np.nan
        row[f'Z{k}x_coef'] = coef
        row[f'Z{k}x_p'] = p
    return row


def _estimate_dv(df, dv):
    """Fit the Z-only baseline and every available Z×moderator model for ``dv``.

    Prints a progress log and returns the result rows; the list is empty
    when the baseline cannot be estimated.
    """
    print(f"\nDependent variable: {dv}")
    print("-" * 50)

    baseline = run_gls(df, dv, Z_VARS)
    if baseline is None:
        print(f"  Baseline failed, skipping {dv}")
        return []

    print(f"  Baseline: N={baseline['n_obs']}, R²={baseline['r_squared']:.3f}")
    for z in Z_VARS:
        c, p = baseline[f'{z}_coef'], baseline[f'{z}_p']
        print(f"    {z}: {fmt(c, p)}")
    rows = [_model_row(dv, '(baseline)', baseline)]

    # Varying-coefficient models: Z + Z×Moderator
    for mod_label, suffix in MODERATORS.items():
        int_vars = [f'{z}_x_{suffix}' for z in Z_VARS]
        if any(v not in df.columns for v in int_vars):
            print(f"  {mod_label}: interaction columns missing, skipping")
            continue

        res = run_gls(df, dv, Z_VARS + int_vars)
        if res is None:
            print(f"  {mod_label}: insufficient data or fit failed")
            continue

        z1x_c = res[f'{int_vars[0]}_coef']
        z1x_p = res[f'{int_vars[0]}_p']
        print(f"  {mod_label}: Z₁×mod={fmt(z1x_c, z1x_p)}, "
              f"N={res['n_obs']}, R²={res['r_squared']:.3f}")
        rows.append(_model_row(dv, mod_label, res, int_vars))
    return rows


def _write_main_effects(results_df):
    """Table A: Z₁ coefficient and Z₁×moderator interaction for every model."""
    # Explicit UTF-8: the tables contain non-ASCII text (Z₁, R², ×) that the
    # platform default encoding (e.g. cp1252) may not be able to encode.
    with open(OUT_TABLES / "phase2_main_effects.md", 'w', encoding='utf-8') as f:
        f.write("# Phase 2: Varying-Coefficient Estimates (Z₁ Focus)\n\n")
        f.write("| DV | Moderator | Z₁ (base) | Z₁×Mod | N | R² |\n")
        f.write("|---|---|---|---|---|---|\n")
        for _, row in results_df.iterrows():
            z1_str = fmt(row['Z_1_coef'], row['Z_1_p'])
            if row['moderator'] == '(baseline)':
                zx_str = '--'
            else:
                zx_str = fmt(row['Z1x_coef'], row['Z1x_p'])
            f.write(f"| {row['dv']} | {row['moderator']} | "
                    f"{z1_str} | {zx_str} | "
                    f"{int(row['n_obs'])} | {row['r2']:.3f} |\n")
        f.write("\n*PanelGLS with AR(1) correction. "
                "\\*p<0.1, \\*\\*p<0.05, \\*\\*\\*p<0.01*\n")
    print("  Wrote: phase2_main_effects.md")


def _write_significant_interactions(results_df):
    """Table B: Z₁×moderator interactions with p < 0.10, most significant first."""
    tested = results_df[results_df['moderator'] != '(baseline)']
    sig = tested[tested['Z1x_p'].notna() & (tested['Z1x_p'] < 0.10)]
    sig = sig.sort_values('Z1x_p')

    with open(OUT_TABLES / "phase2_significant_interactions.md", 'w',
              encoding='utf-8') as f:
        f.write("# Phase 2: Significant Z₁×Moderator Interactions (p<0.10)\n\n")
        f.write("| DV | Moderator | Z₁ (base) | Z₁×Mod | Total Effect | N |\n")
        f.write("|---|---|---|---|---|---|\n")
        for _, row in sig.iterrows():
            z1_str = fmt(row['Z_1_coef'], row['Z_1_p'])
            zx_str = fmt(row['Z1x_coef'], row['Z1x_p'])
            total = row['Z_1_coef'] + row['Z1x_coef']
            f.write(f"| {row['dv']} | {row['moderator']} | "
                    f"{z1_str} | {zx_str} | {total:.3f} | "
                    f"{int(row['n_obs'])} |\n")
        f.write(f"\n*{len(sig)} significant interactions out of "
                f"{len(tested)} tested.*\n")
        f.write("*Total Effect = Z₁ base + Z₁×Mod for countries where Mod=1.*\n")
    print("  Wrote: phase2_significant_interactions.md")


def _write_oecd_diagnosis(results_df):
    """Table C: baseline Z₁ next to the Z₁×OECD model, per DV (in DVS order)."""
    oecd_rows = results_df[results_df['moderator'].isin(['(baseline)', 'is_oecd'])]
    with open(OUT_TABLES / "phase2_oecd_diagnosis.md", 'w', encoding='utf-8') as f:
        f.write("# Phase 2: OECD Null Diagnosis\n\n")
        f.write("Does adding OECD interaction explain why Z₁ weakens in OECD samples?\n\n")
        f.write("| DV | Model | Z₁ | Z₁×OECD | Total (OECD) | N |\n")
        f.write("|---|---|---|---|---|---|\n")
        for dv in DVS:
            for _, row in oecd_rows[oecd_rows['dv'] == dv].iterrows():
                z1_str = fmt(row['Z_1_coef'], row['Z_1_p'])
                if row['moderator'] == '(baseline)':
                    f.write(f"| {dv} | Baseline | {z1_str} | -- | -- | "
                            f"{int(row['n_obs'])} |\n")
                else:
                    zx_str = fmt(row['Z1x_coef'], row['Z1x_p'])
                    total = row['Z_1_coef'] + row['Z1x_coef']
                    f.write(f"| {dv} | +OECD int. | {z1_str} | {zx_str} | "
                            f"{total:.3f} | {int(row['n_obs'])} |\n")
        f.write("\n*If Z₁×OECD is negative and significant, OECD membership attenuates "
                "the demographic effect.*\n")
    print("  Wrote: phase2_oecd_diagnosis.md")


def _print_summary(results_df):
    """Print how many Z₁×moderator interactions clear each significance level."""
    tested_p = results_df.loc[results_df['moderator'] != '(baseline)', 'Z1x_p']
    n_total = len(tested_p)

    print(f"\n{'=' * 70}")
    print("Significant Z₁×Moderator interactions:")
    for label, cutoff in (('0.01', 0.01), ('0.05', 0.05), ('0.10', 0.10)):
        # NaN p-values compare False and are therefore never counted.
        n_sig = int((tested_p < cutoff).sum())
        print(f"  p<{label}: {n_sig}/{n_total}")
    print("\nPhase 2 complete.")


def main():
    """Estimate every DV × moderator specification and write the phase-2 outputs.

    Reads ``unified_panel.csv`` from DATA, writes all estimates to
    ``phase2_varying_coefficients.csv`` in DATA, writes three markdown tables
    to OUT_TABLES, and prints a progress log plus a significance summary.
    """
    print("Phase 2: Varying-Coefficient Estimation")
    print("=" * 70)

    df = pd.read_csv(DATA / "unified_panel.csv")
    print(f"Panel: {len(df)} obs, {df['iso3'].nunique()} countries\n")

    all_results = []
    for dv in DVS:
        if dv not in df.columns:
            print(f"  {dv}: not in panel, skipping")
            continue
        all_results.extend(_estimate_dv(df, dv))

    # Explicit columns keep the writers below working (instead of raising
    # KeyError) even when no model could be estimated at all.
    results_df = pd.DataFrame(all_results, columns=_result_columns())
    results_df.to_csv(DATA / "phase2_varying_coefficients.csv", index=False)

    print(f"\n{'=' * 70}")
    print("Writing output tables...")
    _write_main_effects(results_df)
    _write_significant_interactions(results_df)
    _write_oecd_diagnosis(results_df)
    _print_summary(results_df)


# Run the full phase-2 pipeline only when executed as a script, not on import.
if __name__ == '__main__':
    main()
